{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Train.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LuK-VZQCO9kD",
        "colab_type": "text"
      },
      "source": [
        "**Initialization**"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "nrDnv3enMXvY",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Environment setup for Colab: Kaggle credentials, the two image datasets,\n",
        "# Python dependencies and output directories.\n",
        "# NOTE(review): the training cell loads './vgg_normalised.pth' by default and\n",
        "# nothing in this cell fetches it - confirm where that file should come from.\n",
        "!pip install kaggle\n",
        "\n",
        "from google.colab import files\n",
        "\n",
        "# Upload your kaggle.json here...\n",
        "uploaded = files.upload()\n",
        "\n",
        "for fn in uploaded.keys():\n",
        "  print('User uploaded file \"{name}\" with length {length} bytes'.format(\n",
        "      name=fn, length=len(uploaded[fn])))\n",
        "\n",
        "# Install the API token where the Kaggle CLI looks for it and make it\n",
        "# readable by the current user only.\n",
        "!mkdir -p ~/.kaggle\n",
        "!cp kaggle.json ~/.kaggle\n",
        "!chmod 600 ~/.kaggle/kaggle.json\n",
        "\n",
        "# Style images (./train after unzipping).\n",
        "!kaggle competitions download painter-by-numbers -f train.zip\n",
        "\n",
        "# Content images (./train2014 after unzipping).\n",
        "!wget http://images.cocodataset.org/zips/train2014.zip\n",
        "\n",
        "!unzip -q train.zip\n",
        "!unzip -q train2014.zip\n",
        "\n",
        "# The archives are no longer needed once extracted.\n",
        "!rm train.zip\n",
        "!rm train2014.zip\n",
        "\n",
        "!pip3 install tqdm\n",
        "!pip3 install TensorboardX\n",
        "\n",
        "# -p keeps this cell re-runnable: plain mkdir fails when the directory exists.\n",
        "!mkdir -p experiments\n",
        "!mkdir -p logs"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yIWTzYh0PBgD",
        "colab_type": "text"
      },
      "source": [
        "**Training**"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "tlE8XKhAM_uS",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import torch\n",
        "\n",
        "\n",
        "def calc_mean_std(feat, eps=1e-5):\n",
        "    # eps is a small value added to the variance to avoid divide-by-zero.\n",
        "    size = feat.size()\n",
        "    assert (len(size) == 4)\n",
        "    N, C = size[:2]\n",
        "    feat_var = feat.view(N, C, -1).var(dim=2) + eps\n",
        "    feat_std = feat_var.sqrt().view(N, C, 1, 1)\n",
        "    feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)\n",
        "    return feat_mean, feat_std\n",
        "\n",
        "def mean_variance_norm(feat):\n",
        "    \"\"\"Shift and scale `feat` by its own per-channel mean and std.\n",
        "\n",
        "    The statistics come from calc_mean_std and are broadcast back to the\n",
        "    full size of `feat` before the subtraction and division.\n",
        "    \"\"\"\n",
        "    mu, sigma = calc_mean_std(feat)\n",
        "    full_size = feat.size()\n",
        "    return (feat - mu.expand(full_size)) / sigma.expand(full_size)\n",
        "\n",
        "def _calc_feat_flatten_mean_std(feat):\n",
        "    # takes 3D feat (C, H, W), return mean and std of array within channels\n",
        "    assert (feat.size()[0] == 3)\n",
        "    assert (isinstance(feat, torch.FloatTensor))\n",
        "    feat_flatten = feat.view(3, -1)\n",
        "    mean = feat_flatten.mean(dim=-1, keepdim=True)\n",
        "    std = feat_flatten.std(dim=-1, keepdim=True)\n",
        "    return feat_flatten, mean, std\n",
        "\n",
        "import torch.nn as nn\n",
        "\n",
        "# Decoder: takes a 512-channel feature map and produces a 3-channel output.\n",
        "# Every convolution is 3x3 and is preceded by a 1-pixel reflection pad, so the\n",
        "# convolutions keep the spatial size; the three nearest-neighbour Upsample\n",
        "# layers each double it. There is no activation after the final convolution.\n",
        "# The position of each layer determines its state_dict key, so the order must\n",
        "# stay as it is for saved decoder checkpoints to load.\n",
        "decoder = nn.Sequential(\n",
        "    nn.ReflectionPad2d((1, 1, 1, 1)),\n",
        "    nn.Conv2d(512, 256, (3, 3)),\n",
        "    nn.ReLU(),\n",
        "    nn.Upsample(scale_factor=2, mode='nearest'),\n",
        "    nn.ReflectionPad2d((1, 1, 1, 1)),\n",
        "    nn.Conv2d(256, 256, (3, 3)),\n",
        "    nn.ReLU(),\n",
        "    nn.ReflectionPad2d((1, 1, 1, 1)),\n",
        "    nn.Conv2d(256, 256, (3, 3)),\n",
        "    nn.ReLU(),\n",
        "    nn.ReflectionPad2d((1, 1, 1, 1)),\n",
        "    nn.Conv2d(256, 256, (3, 3)),\n",
        "    nn.ReLU(),\n",
        "    nn.ReflectionPad2d((1, 1, 1, 1)),\n",
        "    nn.Conv2d(256, 128, (3, 3)),\n",
        "    nn.ReLU(),\n",
        "    nn.Upsample(scale_factor=2, mode='nearest'),\n",
        "    nn.ReflectionPad2d((1, 1, 1, 1)),\n",
        "    nn.Conv2d(128, 128, (3, 3)),\n",
        "    nn.ReLU(),\n",
        "    nn.ReflectionPad2d((1, 1, 1, 1)),\n",
        "    nn.Conv2d(128, 64, (3, 3)),\n",
        "    nn.ReLU(),\n",
        "    nn.Upsample(scale_factor=2, mode='nearest'),\n",
        "    nn.ReflectionPad2d((1, 1, 1, 1)),\n",
        "    nn.Conv2d(64, 64, (3, 3)),\n",
        "    nn.ReLU(),\n",
        "    nn.ReflectionPad2d((1, 1, 1, 1)),\n",
        "    nn.Conv2d(64, 3, (3, 3)),\n",
        ")\n",
        "\n",
        "# Encoder layer stack. It starts with a 1x1 convolution on the 3 input\n",
        "# channels, followed by 3x3 convolutions (each after a 1-pixel reflection pad)\n",
        "# grouped into five stages separated by 2x2 max-pooling; the trailing comments\n",
        "# name the activation each ReLU produces.\n",
        "# Weights are loaded into this exact layout from the file given by --vgg, and\n",
        "# Net slices the children by index, so layers must not be added, removed or\n",
        "# reordered.\n",
        "vgg = nn.Sequential(\n",
        "    nn.Conv2d(3, 3, (1, 1)),\n",
        "    nn.ReflectionPad2d((1, 1, 1, 1)),\n",
        "    nn.Conv2d(3, 64, (3, 3)),\n",
        "    nn.ReLU(),  # relu1-1\n",
        "    nn.ReflectionPad2d((1, 1, 1, 1)),\n",
        "    nn.Conv2d(64, 64, (3, 3)),\n",
        "    nn.ReLU(),  # relu1-2\n",
        "    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),\n",
        "    nn.ReflectionPad2d((1, 1, 1, 1)),\n",
        "    nn.Conv2d(64, 128, (3, 3)),\n",
        "    nn.ReLU(),  # relu2-1\n",
        "    nn.ReflectionPad2d((1, 1, 1, 1)),\n",
        "    nn.Conv2d(128, 128, (3, 3)),\n",
        "    nn.ReLU(),  # relu2-2\n",
        "    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),\n",
        "    nn.ReflectionPad2d((1, 1, 1, 1)),\n",
        "    nn.Conv2d(128, 256, (3, 3)),\n",
        "    nn.ReLU(),  # relu3-1\n",
        "    nn.ReflectionPad2d((1, 1, 1, 1)),\n",
        "    nn.Conv2d(256, 256, (3, 3)),\n",
        "    nn.ReLU(),  # relu3-2\n",
        "    nn.ReflectionPad2d((1, 1, 1, 1)),\n",
        "    nn.Conv2d(256, 256, (3, 3)),\n",
        "    nn.ReLU(),  # relu3-3\n",
        "    nn.ReflectionPad2d((1, 1, 1, 1)),\n",
        "    nn.Conv2d(256, 256, (3, 3)),\n",
        "    nn.ReLU(),  # relu3-4\n",
        "    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),\n",
        "    nn.ReflectionPad2d((1, 1, 1, 1)),\n",
        "    nn.Conv2d(256, 512, (3, 3)),\n",
        "    nn.ReLU(),  # relu4-1, this is the last layer used\n",
        "    nn.ReflectionPad2d((1, 1, 1, 1)),\n",
        "    nn.Conv2d(512, 512, (3, 3)),\n",
        "    nn.ReLU(),  # relu4-2\n",
        "    nn.ReflectionPad2d((1, 1, 1, 1)),\n",
        "    nn.Conv2d(512, 512, (3, 3)),\n",
        "    nn.ReLU(),  # relu4-3\n",
        "    nn.ReflectionPad2d((1, 1, 1, 1)),\n",
        "    nn.Conv2d(512, 512, (3, 3)),\n",
        "    nn.ReLU(),  # relu4-4\n",
        "    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),\n",
        "    nn.ReflectionPad2d((1, 1, 1, 1)),\n",
        "    nn.Conv2d(512, 512, (3, 3)),\n",
        "    nn.ReLU(),  # relu5-1\n",
        "    nn.ReflectionPad2d((1, 1, 1, 1)),\n",
        "    nn.Conv2d(512, 512, (3, 3)),\n",
        "    nn.ReLU(),  # relu5-2\n",
        "    nn.ReflectionPad2d((1, 1, 1, 1)),\n",
        "    nn.Conv2d(512, 512, (3, 3)),\n",
        "    nn.ReLU(),  # relu5-3\n",
        "    nn.ReflectionPad2d((1, 1, 1, 1)),\n",
        "    nn.Conv2d(512, 512, (3, 3)),\n",
        "    nn.ReLU()  # relu5-4\n",
        ")\n",
        "\n",
        "class SANet(nn.Module):\n",
        "    def __init__(self, in_planes):\n",
        "        super(SANet, self).__init__()\n",
        "        self.f = nn.Conv2d(in_planes, in_planes, (1, 1))\n",
        "        self.g = nn.Conv2d(in_planes, in_planes, (1, 1))\n",
        "        self.h = nn.Conv2d(in_planes, in_planes, (1, 1))\n",
        "        self.sm = nn.Softmax(dim = -1)\n",
        "        self.out_conv = nn.Conv2d(in_planes, in_planes, (1, 1))\n",
        "    def forward(self, content, style):\n",
        "        F = self.f(mean_variance_norm(content))\n",
        "        G = self.g(mean_variance_norm(style))\n",
        "        H = self.h(style)\n",
        "        b, c, h, w = F.size()\n",
        "        F = F.view(b, -1, w * h).permute(0, 2, 1)\n",
        "        b, c, h, w = G.size()\n",
        "        G = G.view(b, -1, w * h)\n",
        "        S = torch.bmm(F, G)\n",
        "        S = self.sm(S)\n",
        "        b, c, h, w = H.size()\n",
        "        H = H.view(b, -1, w * h)\n",
        "        O = torch.bmm(H, S.permute(0, 2, 1))\n",
        "        b, c, h, w = content.size()\n",
        "        O = O.view(b, c, h, w)\n",
        "        O = self.out_conv(O)\n",
        "        O += content\n",
        "        return O\n",
        "\n",
        "class Transform(nn.Module):\n",
        "    \"\"\"Combine SANet outputs computed at two feature levels.\n",
        "\n",
        "    One SANet handles the 4_1 features and another the 5_1 features. The 5_1\n",
        "    result is upsampled by 2, added to the 4_1 result, and the sum goes\n",
        "    through a reflection-padded 3x3 convolution.\n",
        "    \"\"\"\n",
        "    def __init__(self, in_planes):\n",
        "        super(Transform, self).__init__()\n",
        "        self.sanet4_1 = SANet(in_planes = in_planes)\n",
        "        self.sanet5_1 = SANet(in_planes = in_planes)\n",
        "        self.upsample5_1 = nn.Upsample(scale_factor=2, mode='nearest')\n",
        "        self.merge_conv_pad = nn.ReflectionPad2d((1, 1, 1, 1))\n",
        "        self.merge_conv = nn.Conv2d(in_planes, in_planes, (3, 3))\n",
        "\n",
        "    def forward(self, content4_1, style4_1, content5_1, style5_1):\n",
        "        \"\"\"Return the merged feature map for the given content/style pairs.\"\"\"\n",
        "        level4 = self.sanet4_1(content4_1, style4_1)\n",
        "        level5 = self.upsample5_1(self.sanet5_1(content5_1, style5_1))\n",
        "        return self.merge_conv(self.merge_conv_pad(level4 + level5))\n",
        "\n",
        "class Net(nn.Module):\n",
        "    \"\"\"Training wrapper: frozen encoder, Transform and decoder, returning losses.\n",
        "\n",
        "    Args:\n",
        "        encoder: nn.Sequential whose children are sliced by index into five\n",
        "            stages ending at relu1_1, relu2_1, relu3_1, relu4_1 and relu5_1.\n",
        "        decoder: module that maps the transformed features to an image.\n",
        "        start_iter: when greater than 0, the transform and decoder weights are\n",
        "            loaded from 'transformer_iter_<n>.pth' and 'decoder_iter_<n>.pth'\n",
        "            in the current working directory.\n",
        "    \"\"\"\n",
        "    def __init__(self, encoder, decoder, start_iter):\n",
        "        super(Net, self).__init__()\n",
        "        enc_layers = list(encoder.children())\n",
        "        self.enc_1 = nn.Sequential(*enc_layers[:4])  # input -> relu1_1\n",
        "        self.enc_2 = nn.Sequential(*enc_layers[4:11])  # relu1_1 -> relu2_1\n",
        "        self.enc_3 = nn.Sequential(*enc_layers[11:18])  # relu2_1 -> relu3_1\n",
        "        self.enc_4 = nn.Sequential(*enc_layers[18:31])  # relu3_1 -> relu4_1\n",
        "        self.enc_5 = nn.Sequential(*enc_layers[31:44])  # relu4_1 -> relu5_1\n",
        "        #transform\n",
        "        self.transform = Transform(in_planes = 512)\n",
        "        self.decoder = decoder\n",
        "        # start_iter can arrive as a float (e.g. 1000.0). str() of that would\n",
        "        # give 'transformer_iter_1000.0.pth', which never matches a checkpoint\n",
        "        # named with an integer iteration, so coerce it first.\n",
        "        start_iter = int(start_iter)\n",
        "        if(start_iter > 0):\n",
        "            self.transform.load_state_dict(torch.load('transformer_iter_' + str(start_iter) + '.pth'))\n",
        "            self.decoder.load_state_dict(torch.load('decoder_iter_' + str(start_iter) + '.pth'))\n",
        "        self.mse_loss = nn.MSELoss()\n",
        "        # fix the encoder\n",
        "        for name in ['enc_1', 'enc_2', 'enc_3', 'enc_4', 'enc_5']:\n",
        "            for param in getattr(self, name).parameters():\n",
        "                param.requires_grad = False\n",
        "\n",
        "    # extract relu1_1, relu2_1, relu3_1, relu4_1, relu5_1 from input image\n",
        "    def encode_with_intermediate(self, input):\n",
        "        \"\"\"Run the five encoder stages in sequence and return all five outputs.\"\"\"\n",
        "        results = [input]\n",
        "        for i in range(5):\n",
        "            func = getattr(self, 'enc_{:d}'.format(i + 1))\n",
        "            results.append(func(results[-1]))\n",
        "        # Drop the raw input that seeded the list.\n",
        "        return results[1:]\n",
        "\n",
        "    def calc_content_loss(self, input, target, norm = False):\n",
        "        \"\"\"MSE between input and target; both are mean/variance-normalised first when `norm` is true.\"\"\"\n",
        "        if(norm == False):\n",
        "          return self.mse_loss(input, target)\n",
        "        else:\n",
        "          return self.mse_loss(mean_variance_norm(input), mean_variance_norm(target))\n",
        "\n",
        "    def calc_style_loss(self, input, target):\n",
        "        \"\"\"MSE between the channel means plus MSE between the channel stds.\"\"\"\n",
        "        input_mean, input_std = calc_mean_std(input)\n",
        "        target_mean, target_std = calc_mean_std(target)\n",
        "        return (self.mse_loss(input_mean, target_mean) +\n",
        "                self.mse_loss(input_std, target_std))\n",
        "\n",
        "    def forward(self, content, style):\n",
        "        \"\"\"Stylise `content` with `style` and return the four training losses.\n",
        "\n",
        "        Returns:\n",
        "            (loss_c, loss_s, l_identity1, l_identity2): the content loss on\n",
        "            the relu4_1/relu5_1 features, the style loss summed over all five\n",
        "            encoder stages, the image-space identity loss and the\n",
        "            feature-space identity loss.\n",
        "        \"\"\"\n",
        "        style_feats = self.encode_with_intermediate(style)\n",
        "        content_feats = self.encode_with_intermediate(content)\n",
        "        stylized = self.transform(content_feats[3], style_feats[3], content_feats[4], style_feats[4])\n",
        "        g_t = self.decoder(stylized)\n",
        "        g_t_feats = self.encode_with_intermediate(g_t)\n",
        "        # Content loss compares normalised features of the two deepest stages.\n",
        "        loss_c = self.calc_content_loss(g_t_feats[3], content_feats[3], norm = True) + self.calc_content_loss(g_t_feats[4], content_feats[4], norm = True)\n",
        "        loss_s = self.calc_style_loss(g_t_feats[0], style_feats[0])\n",
        "        for i in range(1, 5):\n",
        "            loss_s += self.calc_style_loss(g_t_feats[i], style_feats[i])\n",
        "        # Identity losses: an image used as both content and style should be\n",
        "        # reproduced, in image space (l_identity1) and in encoder feature\n",
        "        # space (l_identity2).\n",
        "        Icc = self.decoder(self.transform(content_feats[3], content_feats[3], content_feats[4], content_feats[4]))\n",
        "        Iss = self.decoder(self.transform(style_feats[3], style_feats[3], style_feats[4], style_feats[4]))\n",
        "        l_identity1 = self.calc_content_loss(Icc, content) + self.calc_content_loss(Iss, style)\n",
        "        Fcc = self.encode_with_intermediate(Icc)\n",
        "        Fss = self.encode_with_intermediate(Iss)\n",
        "        l_identity2 = self.calc_content_loss(Fcc[0], content_feats[0]) + self.calc_content_loss(Fss[0], style_feats[0])\n",
        "        for i in range(1, 5):\n",
        "            l_identity2 += self.calc_content_loss(Fcc[i], content_feats[i]) + self.calc_content_loss(Fss[i], style_feats[i])\n",
        "        return loss_c, loss_s, l_identity1, l_identity2\n",
        "\n",
        "import numpy as np\n",
        "from torch.utils import data\n",
        "\n",
        "\n",
        "def InfiniteSampler(n):\n",
        "    # i = 0\n",
        "    i = n - 1\n",
        "    order = np.random.permutation(n)\n",
        "    while True:\n",
        "        yield order[i]\n",
        "        i += 1\n",
        "        if i >= n:\n",
        "            np.random.seed()\n",
        "            order = np.random.permutation(n)\n",
        "            i = 0\n",
        "\n",
        "\n",
        "class InfiniteSamplerWrapper(data.sampler.Sampler):\n",
        "    def __init__(self, data_source):\n",
        "        self.num_samples = len(data_source)\n",
        "\n",
        "    def __iter__(self):\n",
        "        return iter(InfiniteSampler(self.num_samples))\n",
        "\n",
        "    def __len__(self):\n",
        "        return 2 ** 31\n",
        "\n",
        "import argparse\n",
        "import os\n",
        "import torch\n",
        "import torch.backends.cudnn as cudnn\n",
        "import torch.nn as nn\n",
        "import torch.utils.data as data\n",
        "from PIL import Image\n",
        "from PIL import ImageFile\n",
        "from tensorboardX import SummaryWriter\n",
        "from torchvision import transforms\n",
        "from tqdm import tqdm\n",
        "\n",
        "# Process-wide switches, set once before any data is loaded.\n",
        "# Turn on cuDNN's benchmark mode (see the torch.backends.cudnn documentation).\n",
        "cudnn.benchmark = True\n",
        "Image.MAX_IMAGE_PIXELS = None  # Disable DecompressionBombError\n",
        "ImageFile.LOAD_TRUNCATED_IMAGES = True  # Disable OSError: image file is truncated\n",
        "\n",
        "\n",
        "def train_transform():\n",
        "    \"\"\"Build the training pipeline: resize to 512x512, random 256 crop, to tensor.\"\"\"\n",
        "    return transforms.Compose([\n",
        "        transforms.Resize(size=(512, 512)),\n",
        "        transforms.RandomCrop(256),\n",
        "        transforms.ToTensor(),\n",
        "    ])\n",
        "\n",
        "\n",
        "class FlatFolderDataset(data.Dataset):\n",
        "    def __init__(self, root, transform):\n",
        "        super(FlatFolderDataset, self).__init__()\n",
        "        self.root = root\n",
        "        self.paths = os.listdir(self.root)\n",
        "        self.transform = transform\n",
        "\n",
        "    def __getitem__(self, index):\n",
        "        path = self.paths[index]\n",
        "        img = Image.open(os.path.join(self.root, path)).convert('RGB')\n",
        "        img = self.transform(img)\n",
        "        return img\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.paths)\n",
        "\n",
        "    def name(self):\n",
        "        return 'FlatFolderDataset'\n",
        "\n",
        "\n",
        "def adjust_learning_rate(optimizer, iteration_count):\n",
        "    \"\"\"Set every param group's lr to args.lr / (1 + args.lr_decay * iteration_count).\n",
        "\n",
        "    Reads the base rate and decay factor from the module-level `args`.\n",
        "    \"\"\"\n",
        "    decayed = args.lr / (1.0 + args.lr_decay * iteration_count)\n",
        "    for group in optimizer.param_groups:\n",
        "        group.update(lr=decayed)\n",
        "\n",
        "\n",
        "parser = argparse.ArgumentParser()\n",
        "# Basic options\n",
        "parser.add_argument('--content_dir', type=str, default='./train2014',\n",
        "                    help='Directory path to a batch of content images')\n",
        "parser.add_argument('--style_dir', type=str, default='./train',\n",
        "                    help='Directory path to a batch of style images')\n",
        "parser.add_argument('--vgg', type=str, default='./vgg_normalised.pth')\n",
        "\n",
        "# training options\n",
        "parser.add_argument('--save_dir', default='./experiments',\n",
        "                    help='Directory to save the model')\n",
        "parser.add_argument('--log_dir', default='./logs',\n",
        "                    help='Directory to save the log')\n",
        "parser.add_argument('--lr', type=float, default=1e-4)\n",
        "parser.add_argument('--lr_decay', type=float, default=5e-5)\n",
        "parser.add_argument('--max_iter', type=int, default=160000)\n",
        "parser.add_argument('--batch_size', type=int, default=5)\n",
        "parser.add_argument('--style_weight', type=float, default=3.0)\n",
        "parser.add_argument('--content_weight', type=float, default=1.0)\n",
        "parser.add_argument('--n_threads', type=int, default=16)\n",
        "parser.add_argument('--save_model_interval', type=int, default=1000)\n",
        "parser.add_argument('--start_iter', type=float, default=0)\n",
        "args = parser.parse_args('')\n",
        "\n",
        "# ---- Training driver ----\n",
        "# Prefer the GPU, but fall back to the CPU so the cell still runs without CUDA.\n",
        "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
        "\n",
        "# Make sure the output directories exist even if the setup cell was skipped.\n",
        "os.makedirs(args.save_dir, exist_ok=True)\n",
        "os.makedirs(args.log_dir, exist_ok=True)\n",
        "\n",
        "# range() and the '{:d}' checkpoint names below need an integer iteration.\n",
        "start_iter = int(args.start_iter)\n",
        "\n",
        "\n",
        "def checkpoint_path(prefix, iteration):\n",
        "    \"\"\"Return '<save_dir>/<prefix>_iter_<iteration>.pth'.\"\"\"\n",
        "    return '{:s}/{:s}_iter_{:d}.pth'.format(args.save_dir, prefix, iteration)\n",
        "\n",
        "\n",
        "def save_on_cpu(state_dict, path):\n",
        "    \"\"\"Move every tensor of a module state dict to the CPU and write it to `path`.\"\"\"\n",
        "    cpu = torch.device('cpu')\n",
        "    torch.save({key: value.to(cpu) for key, value in state_dict.items()}, path)\n",
        "\n",
        "\n",
        "vgg.load_state_dict(torch.load(args.vgg))\n",
        "vgg = nn.Sequential(*list(vgg.children())[:44])\n",
        "\n",
        "# Net itself only looks for checkpoints in the working directory, whereas this\n",
        "# loop writes them to args.save_dir. Build the network without resuming and\n",
        "# restore the weights from args.save_dir here, so a run can continue from the\n",
        "# files it saved.\n",
        "network = Net(vgg, decoder, 0)\n",
        "if start_iter > 0:\n",
        "    network.transform.load_state_dict(\n",
        "        torch.load(checkpoint_path('transformer', start_iter), map_location=device))\n",
        "    network.decoder.load_state_dict(\n",
        "        torch.load(checkpoint_path('decoder', start_iter), map_location=device))\n",
        "network.train()\n",
        "network.to(device)\n",
        "\n",
        "content_tf = train_transform()\n",
        "style_tf = train_transform()\n",
        "\n",
        "content_dataset = FlatFolderDataset(args.content_dir, content_tf)\n",
        "style_dataset = FlatFolderDataset(args.style_dir, style_tf)\n",
        "\n",
        "content_iter = iter(data.DataLoader(\n",
        "    content_dataset, batch_size=args.batch_size,\n",
        "    sampler=InfiniteSamplerWrapper(content_dataset),\n",
        "    num_workers=args.n_threads))\n",
        "style_iter = iter(data.DataLoader(\n",
        "    style_dataset, batch_size=args.batch_size,\n",
        "    sampler=InfiniteSamplerWrapper(style_dataset),\n",
        "    num_workers=args.n_threads))\n",
        "\n",
        "# Only the decoder and the transform are trained.\n",
        "optimizer = torch.optim.Adam([\n",
        "                              {'params': network.decoder.parameters()},\n",
        "                              {'params': network.transform.parameters()}], lr=args.lr)\n",
        "\n",
        "if start_iter > 0:\n",
        "    optimizer.load_state_dict(\n",
        "        torch.load(checkpoint_path('optimizer', start_iter), map_location=device))\n",
        "\n",
        "# TensorBoard log of the individual loss terms, written to args.log_dir.\n",
        "writer = SummaryWriter(log_dir=args.log_dir)\n",
        "\n",
        "for i in tqdm(range(start_iter, args.max_iter)):\n",
        "    adjust_learning_rate(optimizer, iteration_count=i)\n",
        "    content_images = next(content_iter).to(device)\n",
        "    style_images = next(style_iter).to(device)\n",
        "    loss_c, loss_s, l_identity1, l_identity2 = network(content_images, style_images)\n",
        "    loss_c = args.content_weight * loss_c\n",
        "    loss_s = args.style_weight * loss_s\n",
        "    # The identity terms use fixed weights of 50 and 1.\n",
        "    loss = loss_c + loss_s + l_identity1 * 50 + l_identity2 * 1\n",
        "\n",
        "    optimizer.zero_grad()\n",
        "    loss.backward()\n",
        "    optimizer.step()\n",
        "\n",
        "    writer.add_scalar('loss_content', loss_c.item(), i + 1)\n",
        "    writer.add_scalar('loss_style', loss_s.item(), i + 1)\n",
        "    writer.add_scalar('loss_identity1', l_identity1.item(), i + 1)\n",
        "    writer.add_scalar('loss_identity2', l_identity2.item(), i + 1)\n",
        "\n",
        "    # Checkpoint at every save interval and once more on the last iteration.\n",
        "    if (i + 1) % args.save_model_interval == 0 or (i + 1) == args.max_iter:\n",
        "        save_on_cpu(network.decoder.state_dict(), checkpoint_path('decoder', i + 1))\n",
        "        save_on_cpu(network.transform.state_dict(), checkpoint_path('transformer', i + 1))\n",
        "        torch.save(optimizer.state_dict(), checkpoint_path('optimizer', i + 1))\n",
        "writer.close()"
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}